%matplotlib inline
import os
os.chdir("../")
import random
from time import time
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from portraitseg.utils import plots, denormalizer, show_input_output_target
from portraitseg.pytorch_datasets import FlickrPortraitMaskDataset
from portraitseg.pytorch_dataloaders import get_train_valid_loader
class Net(nn.Module):
    """Minimal baseline model: a single 3x3 convolution mapping an RGB
    image (3 channels) to a 1-channel mask logit of the same spatial size
    (padding=1 preserves height and width)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=1,
                               kernel_size=3, padding=1)

    def forward(self, x):
        out = self.conv1(x)
        return out
# Seed every RNG in play (torch, random, numpy) for reproducible runs.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
# NOTE(review): os.chdir("../") already ran at the top of the file, so this
# relative path resolves against the new working directory — confirm layout.
DATA_DIR = "../data/"
FLICKR_DIR = DATA_DIR + "portraits/flickr/"
# When True, the training loop below breaks after one batch per epoch —
# a smoke test of the whole pipeline, not a real training run.
TESTING_PIPES = True
# Hyperparameters
BATCH_SIZE = 1
LR = 1e-2
NB_EPOCHS = 2
AUGMENT = False
# Model, loss and optimizer; the model is moved to the GPU (requires CUDA).
net = Net().cuda()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=LR)
# Project helper: splits the Flickr portrait/mask dataset into train and
# validation loaders (20% validation) with the same seed as above.
trn_loader, val_loader = get_train_valid_loader(FLICKR_DIR,
                                                batch_size=BATCH_SIZE,
                                                augment=AUGMENT,
                                                random_seed=SEED,
                                                valid_size=0.2,
                                                show_sample=True,
                                                num_workers=1,
                                                pin_memory=True)
# Train for NB_EPOCHS epochs; after each epoch run a validation pass and
# keep one model output on a fixed validation batch so progress can be
# visualized at the end.
# NOTE(review): the original paste had its indentation stripped; the loop
# nesting below (per-epoch validation and per-epoch val_outputs.append)
# is reconstructed from the prints and the epoch==0 capture — confirm
# against the original notebook.
val_outputs = []
for epoch in range(NB_EPOCHS):
    start = time()
    running_loss = 0.0
    print("\n[Epoch, batches]")
    for i, sample_batch in enumerate(trn_loader, 0):
        portraits, masks = sample_batch
        portraits, masks = Variable(portraits).cuda(), Variable(masks).cuda()
        outputs = net(portraits)
        loss = loss_fn(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # loss.item() extracts the Python float from the 0-dim loss tensor;
        # the old `loss.data[0]` raises IndexError on PyTorch >= 0.4.
        running_loss += loss.item()
        if i % 100 == 99:
            # Report the mean loss over the last 100 batches.
            print("[%d, %4d] loss: %.3f" % (epoch+1, i+1, running_loss/100))
            running_loss = 0.0
        if TESTING_PIPES:
            # Smoke-test mode: one training batch per epoch is enough.
            break
    # Test on validation set — no parameter updates, so skip gradient
    # tracking to save memory and time.
    running_loss_val = 0.0
    with torch.no_grad():
        for i, sample_batch in enumerate(val_loader, 0):
            portraits, masks = sample_batch
            portraits, masks = Variable(portraits).cuda(), Variable(masks).cuda()
            outputs = net(portraits)
            loss = loss_fn(outputs, masks)
            running_loss_val += loss.item()
            if i == 1 and epoch == 0:
                # Pin one validation batch (first epoch only) so every
                # epoch's output is computed on the same images.
                portraits_v = portraits
                masks_v = masks
    print("Validation loss: %.3f" % (running_loss_val/len(val_loader)))
    print("Epoch duration: %.2f seconds" % (time() - start))
    # One output per epoch on the pinned batch, for side-by-side display.
    val_outputs.append(net(portraits_v))
# Visualize input / per-epoch outputs / target for the pinned batch.
show_input_output_target(portraits_v,
                         val_outputs,
                         masks_v,
                         denormalizer)
print("Training complete.")